#include "POF/c_api.h"

inline int POF_APIHandleException(const std::exception& ex) {
  POF_SetLastError(ex.what());
  return -1;
}
inline int POF_APIHandleException(const std::string& ex) {
  POF_SetLastError(ex.c_str());
  return -1;
}

#define API_BEGIN() try {
#define API_END() } \
catch(std::exception& ex) { return POF_APIHandleException(ex); } \
catch(std::string& ex) { return POF_APIHandleException(ex); } \
catch(...) { return POF_APIHandleException("unknown exception"); } \
return 0;

// #define API_BEGIN()
// #define API_END() return 0;

const char* POF_GetLastError() {
  return LastErrorMsg();
}

int POF_init_userparameter(struct User_parameter** p_user_parameter){
    API_BEGIN();
    init_userparameter(p_user_parameter);
    API_END();
}
int POF_print_userparameter(struct User_parameter* user_parameter){
    API_BEGIN();
    print_userparameter(user_parameter);
    API_END();
}
int POF_set_userparameter(
	int n_rank,
	double th,
	double beta,
    double min_loss,
	double lambda,
	double xi,
	double context_list[],
	struct User_parameter* user_parameter
){
    API_BEGIN();
    set_userparameter(
        n_rank,
        th,
        beta,
        min_loss,
        lambda,
        xi,
        context_list,
        user_parameter
    );
    API_END();
}
int POF_update_xi(
    double xi,
    struct User_parameter* user_parameter
){
    API_BEGIN();
    update_xi(
        xi,
        user_parameter
    );
    API_END();
}
int POF_update_min_loss(
    double min_loss,
    struct User_parameter* user_parameter
){
    API_BEGIN();
    update_min_loss(
        min_loss,
        user_parameter
    );
    API_END();
}
int POF_set_w_cnt(
    double w_cnt
){
    API_BEGIN();
    //set_w_cnt(w_cnt);
    API_END();
}

int POF_init_nndata(struct NN_data** p_nndata){
    API_BEGIN();
    init_nndata(p_nndata);
    API_END();
}
int POF_print_nndata(struct NN_data* nn_data){
    API_BEGIN();
    print_nndata(nn_data);
    API_END();
}
double POF_get_bias(struct NN_data* nn_data){
    cout<<"Bias: "<<nn_data->bias<<endl;
    return nn_data->bias;
}
int POF_update_train_nndata(int batch_size, int n_class, double objective_list[], double ccon_w_list[], struct NN_data* nn_data){
    API_BEGIN();
    update_train_nndata(
        batch_size,
        n_class,
        objective_list,
        ccon_w_list,
        nn_data
    );
    API_END();
}
int POF_update_intest_nndata(int n_internal_test_case, int n_class, int ans_list[], double test_case_ccon_w_list[], struct NN_data* nn_data){
    API_BEGIN();
    update_intest_nndata(
        n_internal_test_case,
        n_class,
        ans_list,
        test_case_ccon_w_list,
        nn_data
    );
    API_END();
}
int POF_update_nndata(int batch_size, int n_internal_test_case, int n_class, double objective_list[], int ans_list[], double ccon_w_list[], double test_case_ccon_w_list[], struct NN_data* nn_data){
    API_BEGIN();
    update_nndata(
        batch_size,
        n_internal_test_case,
        n_class,
        objective_list,
        ans_list,
        ccon_w_list,
        test_case_ccon_w_list,
        nn_data
    );
    API_END();
}

int POF_init_indata(struct Indata** p_indata){
    API_BEGIN();
    init_indata(p_indata);
    API_END();
}
int POF_print_indata(struct Indata* indata){
    API_BEGIN();
    print_indata(indata);
    API_END();
}
double POF_get_posterior(struct Indata* indata, int i, int j, int k){
    // API_BEGIN();
    return get_posterior(indata, i, j, k);
    // API_END();
}
int POF_set_bias(int mode, struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata){
    API_BEGIN();
    set_bias(mode, nn_data, user_parameter, indata);
    API_END();
}
int POF_update_indata(int mode, struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata){
    API_BEGIN();
    update_indata(mode, nn_data, user_parameter, indata);
    API_END();
}

int POF_get_solution(NN_data* nn_data, Indata* indata, User_parameter* user_parameter){
    return get_solution(nn_data, indata, user_parameter);
}
double POF_calc_approx_loss(NN_data* nn_data, Indata* indata, User_parameter* user_parameter){
    // API_BEGIN();
    return calc_approx_loss(nn_data, indata, user_parameter);
    // API_END();
}

int POF_calc_grad_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[]){
    API_BEGIN();
    calc_grad_ccon(nn_data, user_parameter, indata, ccon_gradient_list);
    API_END();
}
int POF_calc_grad_test_case_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double test_case_ccon_gradient_list[]){
    API_BEGIN();
    calc_grad_test_case_ccon(nn_data, user_parameter, indata, test_case_ccon_gradient_list);
    API_END();
}
int POF_estimate_ccon_grad(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[], double ccon_gradient_square_list[], int n_sample){
    API_BEGIN();
    estimate_ccon_grad(nn_data, user_parameter, indata, ccon_gradient_list, ccon_gradient_square_list, n_sample);
    API_END();
}
int POF_estimate_test_case_ccon_grad(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double test_case_ccon_gradient_list[], double test_case_ccon_gradient_square_list[], int n_sample){
    API_BEGIN();
    estimate_test_case_ccon_grad(nn_data, user_parameter, indata, test_case_ccon_gradient_list, test_case_ccon_gradient_square_list, n_sample);
    API_END();
}
int POF_estimate_grad(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[], double ccon_gradient_square_list[], double test_case_ccon_gradient_list[], double test_case_ccon_gradient_square_list[], int n_sample){
    API_BEGIN();
    estimate_grad(nn_data, user_parameter, indata, ccon_gradient_list, ccon_gradient_square_list, test_case_ccon_gradient_list, test_case_ccon_gradient_square_list, n_sample);
    API_END();
}
int POF_fullcalc_grad_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[]){
    API_BEGIN();
    fullcalc_grad_ccon(nn_data, user_parameter, indata, ccon_gradient_list);
    API_END();
}
int POF_fullcalc_grad_ccon_sample_grad(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[], double clip){
    API_BEGIN();
    fullcalc_grad_ccon_sample_grad(nn_data, user_parameter, indata, ccon_gradient_list, clip);
    API_END();
}
int POF_fullcalc_grad_test_case_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double test_case_ccon_gradient_list[]){
    API_BEGIN();
    fullcalc_grad_test_case_ccon(nn_data, user_parameter, indata, test_case_ccon_gradient_list);
    API_END();
}
int POF_fullcalc_par20_grad_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[]){
    API_BEGIN();
    fullcalc_par20_grad_ccon(nn_data, user_parameter, indata, ccon_gradient_list);
    API_END();
}
int POF_fullcalc_par20_grad_test_case_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double test_case_ccon_gradient_list[]){
    API_BEGIN();
    fullcalc_par20_grad_test_case_ccon(nn_data, user_parameter, indata, test_case_ccon_gradient_list);
    API_END();
}
int POF_fullcalc_approx_grad_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[]){
    API_BEGIN();
    fullcalc_approx_grad_ccon(nn_data, user_parameter, indata, ccon_gradient_list);
    API_END();
}
int POF_fullcalc_approx_grad_test_case_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double test_case_ccon_gradient_list[]){
    API_BEGIN();
    fullcalc_approx_grad_test_case_ccon(nn_data, user_parameter, indata, test_case_ccon_gradient_list);
    API_END();
}

